{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "# Custom training loop with Keras and MultiWorkerMirroredStrategy\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/distribute/multi_worker_with_ctl\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_ctl.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_ctl.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/multi_worker_with_ctl.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial demonstrates multi-worker training with the custom training loop API, distributed via `MultiWorkerMirroredStrategy`, so that a Keras model designed to run on a [single worker](https://www.tensorflow.org/tutorials/distribute/custom_training) can seamlessly work on multiple workers with minimal code changes.\n",
        "\n",
        "Custom training loops give you flexibility and greater control over training. They also make it easier to debug the model and the training loop. More detailed information is available in [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch).\n",
        "\n",
        "If you are looking for how to use `MultiWorkerMirroredStrategy` with Keras `model.fit`, refer to this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) instead.\n",
        "\n",
        "The [Distributed Training in TensorFlow](../../guide/distributed_training.ipynb) guide provides an overview of the distribution strategies TensorFlow supports, for those interested in a deeper understanding of the `tf.distribute.Strategy` APIs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup\n",
        "\n",
        "First, some necessary imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bnYxvfLD-LW-"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import sys"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zz0EY91y3mxy"
      },
      "source": [
        "Before importing TensorFlow, make a few changes to the environment.\n",
        "\n",
        "Disable all GPUs. This prevents errors caused by the workers all trying to use the same GPU. For a real application each worker would be on a different machine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "685pbYEY3jGC"
      },
      "outputs": [],
      "source": [
        "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7X1MS6385BWi"
      },
      "source": [
        "Reset the `TF_CONFIG` environment variable; you'll see more about this later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WEJLYa2_7OZF"
      },
      "outputs": [],
      "source": [
        "os.environ.pop('TF_CONFIG', None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rd4L9Ii77SS8"
      },
      "source": [
        "Be sure that the current directory is on Python's path. This allows the notebook to import the files written by `%%writefile` later.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hPBuZUNSZmrQ"
      },
      "outputs": [],
      "source": [
        "if '.' not in sys.path:\n",
        "  sys.path.insert(0, '.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pDhHuMjb7bfU"
      },
      "source": [
        "Now import TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vHNvttzV43sA"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0S2jpf6Sx50i"
      },
      "source": [
        "### Dataset and model definition"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fLW6D2TzvC-4"
      },
      "source": [
        "Next, create an `mnist.py` file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dma_wUAxZqo2"
      },
      "outputs": [],
      "source": [
        "%%writefile mnist.py\n",
        "\n",
        "import os\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "def mnist_dataset(batch_size):\n",
        "  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
        "  # The `x` arrays are in uint8 and have values in the range [0, 255].\n",
        "  # You need to convert them to float32 with values in the range [0, 1]\n",
        "  x_train = x_train / np.float32(255)\n",
        "  y_train = y_train.astype(np.int64)\n",
        "  train_dataset = tf.data.Dataset.from_tensor_slices(\n",
        "      (x_train, y_train)).shuffle(60000)\n",
        "  return train_dataset\n",
        "\n",
        "def dataset_fn(global_batch_size, input_context):\n",
        "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
        "  dataset = mnist_dataset(batch_size)\n",
        "  dataset = dataset.shard(input_context.num_input_pipelines,\n",
        "                          input_context.input_pipeline_id)\n",
        "  dataset = dataset.batch(batch_size)\n",
        "  return dataset\n",
        "\n",
        "def build_cnn_model():\n",
        "  return tf.keras.Sequential([\n",
        "      tf.keras.Input(shape=(28, 28)),\n",
        "      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(128, activation='relu'),\n",
        "      tf.keras.layers.Dense(10)\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JmgZwwymxqt5"
      },
      "source": [
        "## Multi-worker Configuration\n",
        "\n",
        "Now let's enter the world of multi-worker training. In TensorFlow, the `TF_CONFIG` environment variable is required for training on multiple machines, each of which possibly has a different role. The `TF_CONFIG` used below is a JSON string that specifies the cluster configuration on each worker that is part of the cluster. This is the default method for specifying a cluster, using `cluster_resolver.TFConfigClusterResolver`, but there are other options available in the `distribute.cluster_resolver` module."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SS8WhvRhe_Ya"
      },
      "source": [
        "### Describe your cluster\n",
        "Here is an example configuration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XK1eTYvSZiX7"
      },
      "outputs": [],
      "source": [
        "tf_config = {\n",
        "    'cluster': {\n",
        "        'worker': ['localhost:12345', 'localhost:23456']\n",
        "    },\n",
        "    'task': {'type': 'worker', 'index': 0}\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjgwJbPKZkJL"
      },
      "source": [
        "Here is the same `TF_CONFIG` serialized as a JSON string:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yY-T0YDQZjbu"
      },
      "outputs": [],
      "source": [
        "json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AUBmYRZqxthH"
      },
      "source": [
        "There are two components of `TF_CONFIG`: `cluster` and `task`.\n",
        "\n",
        "* `cluster` is the same for all workers and provides information about the training cluster, which is a dict consisting of different types of jobs such as `worker`. In multi-worker training with `MultiWorkerMirroredStrategy`, there is usually one `worker` that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular `worker` does. Such a worker is referred to as the `chief` worker, and it is customary that the `worker` with `index` 0 is appointed as the chief `worker` (in fact this is how `tf.distribute.Strategy` is implemented).\n",
        "\n",
        "* `task` provides information about the current task and is different on each worker. It specifies the `type` and `index` of that worker."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8YFpxrcsZ2xG"
      },
      "source": [
        "In this example, you set the task `type` to `\"worker\"` and the task `index` to `0`. This machine is the first worker, so it is appointed as the chief worker and does more work than the others. Note that the other machines need to have the `TF_CONFIG` environment variable set as well, with the same `cluster` dict but a different task `type` or task `index`, depending on the roles of those machines.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aogb74kHxynz"
      },
      "source": [
        "For illustration purposes, this tutorial shows how one may set a `TF_CONFIG` with 2 workers on `localhost`. In practice, users would create multiple workers on external IP addresses/ports, and set `TF_CONFIG` on each worker appropriately.\n",
        "\n",
        "In this example you will use 2 workers. The first worker's `TF_CONFIG` is shown above. For the second worker you would set `tf_config['task']['index']=1`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f83FVYqDX3aX"
      },
      "source": [
        "Above, `tf_config` is just a local variable in Python. To actually use it to configure training, this dictionary needs to be serialized as JSON and placed in the `TF_CONFIG` environment variable."
      ]
    },
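    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the two steps above (changing the task `index` per worker, then serializing to JSON), the per-worker `TF_CONFIG` strings for the two-worker `localhost` cluster could be generated as follows. This is illustrative only and is not executed by the tutorial; `make_tf_config` is a hypothetical helper introduced here, not a TensorFlow API:\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "# Cluster description shared by every worker (same addresses as the example above).\n",
        "cluster = {'worker': ['localhost:12345', 'localhost:23456']}\n",
        "\n",
        "def make_tf_config(index):\n",
        "  # Hypothetical helper: returns the JSON string for the worker at `index`.\n",
        "  return json.dumps({\n",
        "      'cluster': cluster,\n",
        "      'task': {'type': 'worker', 'index': index},\n",
        "  })\n",
        "\n",
        "# One JSON string per worker; only the task index differs between them.\n",
        "worker_configs = [make_tf_config(i) for i in range(len(cluster['worker']))]\n",
        "for config in worker_configs:\n",
        "  print(config)\n",
        "```\n",
        "\n",
        "Each string would then be placed in the `TF_CONFIG` environment variable of the corresponding worker process."
      ]
    },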
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cIlkfWmjz1PG"
      },
      "source": [
        "### Environment variables and subprocesses in notebooks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FcjAbuGY1ACJ"
      },
      "source": [
        "Subprocesses inherit environment variables from their parent. So if you set an environment variable in this `jupyter notebook` process:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PH2gHn2_0_U8"
      },
      "outputs": [],
      "source": [
        "os.environ['GREETINGS'] = 'Hello TensorFlow!'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gQkIX-cg18md"
      },
      "source": [
        "You can access the environment variable from a subprocess:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pquKO6IA18G5"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "echo ${GREETINGS}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "af6BCA-Y2fpz"
      },
      "source": [
        "In the next section, you'll use this to pass the `TF_CONFIG` to the worker subprocesses. You would never really launch your jobs this way, but it's sufficient for the purposes of this tutorial: to demonstrate a minimal multi-worker example."
      ]
    },
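    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same inheritance mechanism can be sketched in plain Python. The snippet below is an illustration rather than the approach this tutorial uses: it assumes you launch the child with the standard-library `subprocess` module and gives that child its own `TF_CONFIG` through the `env` argument. The child only echoes the variable back; a real worker would run a training script instead.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import os\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "tf_config = {\n",
        "    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},\n",
        "    'task': {'type': 'worker', 'index': 1},\n",
        "}\n",
        "\n",
        "# Copy the parent's environment and override TF_CONFIG for this child only.\n",
        "child_env = dict(os.environ, TF_CONFIG=json.dumps(tf_config))\n",
        "\n",
        "# The child process prints the TF_CONFIG value it received.\n",
        "child_code = \"import os; print(os.environ['TF_CONFIG'])\"\n",
        "result = subprocess.run(\n",
        "    [sys.executable, '-c', child_code],\n",
        "    env=child_env, capture_output=True, text=True, check=True)\n",
        "\n",
        "received = json.loads(result.stdout)\n",
        "print(received['task'])\n",
        "```\n",
        "\n",
        "Passing `env` explicitly leaves the notebook's own environment untouched, which can be convenient when each child needs a different task `index`."
      ]
    },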
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## MultiWorkerMirroredStrategy\n",
        "\n",
        "To train the model, use an instance of `tf.distribute.MultiWorkerMirroredStrategy`, which creates copies of all variables in the model's layers on each device across all workers. The [`tf.distribute.Strategy` guide](../../guide/distributed_training.ipynb) has more details about this strategy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1uFSHCJXMrQ-"
      },
      "outputs": [],
      "source": [
        "strategy = tf.distribute.MultiWorkerMirroredStrategy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N0iv7SyyAohc"
      },
      "source": [
        "Note: `TF_CONFIG` is parsed and TensorFlow's gRPC servers are started at the time `MultiWorkerMirroredStrategy()` is called, so the `TF_CONFIG` environment variable must be set before a `tf.distribute.Strategy` instance is created."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TS4S-faBHHam"
      },
      "source": [
        "Use `tf.distribute.Strategy.scope` to specify that a strategy should be used when building your model. This puts you in the \"[cross-replica context](https://www.tensorflow.org/guide/distributed_training?hl=en#mirroredstrategy)\" for this strategy, which means the strategy is put in control of things like variable placement."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nXV49tG1_opc"
      },
      "outputs": [],
      "source": [
        "import mnist\n",
        "with strategy.scope():\n",
        "  # Model building needs to be within `strategy.scope()`.\n",
        "  multi_worker_model = mnist.build_cnn_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DSYkM-on6r3Y"
      },
      "source": [
        "## Auto-shard your data across workers\n",
        "In multi-worker training, dataset sharding is not strictly necessary; however, it gives you exactly-once semantics, which makes training more reproducible, i.e. training on multiple workers should be the same as training on one worker. Note: performance can be affected in some cases.\n",
        "\n",
        "See: [`distribute_datasets_from_function`](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy?version=nightly#distribute_datasets_from_function)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "65-p36pt6rUF"
      },
      "outputs": [],
      "source": [
        "per_worker_batch_size = 64\n",
        "num_workers = len(tf_config['cluster']['worker'])\n",
        "global_batch_size = per_worker_batch_size * num_workers\n",
        "\n",
        "with strategy.scope():\n",
        "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
        "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))"
      ]
    },
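    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the batch-size and sharding arithmetic concrete, here is a small pure-Python sketch. It assumes one replica per worker and ignores shuffling; the `shard` function below is a stand-in that mimics how `tf.data.Dataset.shard` selects elements, not the TensorFlow implementation itself:\n",
        "\n",
        "```python\n",
        "num_workers = 2\n",
        "per_worker_batch_size = 64\n",
        "global_batch_size = per_worker_batch_size * num_workers\n",
        "\n",
        "# Assumption: one replica per worker, so each replica receives an equal\n",
        "# share of the global batch.\n",
        "per_replica_batch_size = global_batch_size // num_workers\n",
        "\n",
        "def shard(elements, num_shards, index):\n",
        "  # Keep every `num_shards`-th element, starting at position `index`.\n",
        "  return elements[index::num_shards]\n",
        "\n",
        "examples = list(range(10))  # stand-in for dataset elements\n",
        "shards = [shard(examples, num_workers, i) for i in range(num_workers)]\n",
        "\n",
        "print(per_replica_batch_size)\n",
        "print(shards)\n",
        "```\n",
        "\n",
        "In this simplified setting every element lands in exactly one shard, which is the exactly-once property described above."
      ]
    },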
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rkNzSR3g60iP"
      },
      "source": [
        "## Define Custom Training Loop and Train the model\n",
        "Specify an optimizer and an accuracy metric:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NoMr4_zTeKSn"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # The creation of optimizer and train_accuracy will need to be in\n",
        "  # `strategy.scope()` as well, since they create variables.\n",
        "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
        "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='train_accuracy')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RmrDcAii4B5O"
      },
      "source": [
        "Define a training step with `tf.function`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "znXWN5S3eUDB"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(iterator):\n",
        "  \"\"\"Training step function.\"\"\"\n",
        "\n",
        "  def step_fn(inputs):\n",
        "    \"\"\"Per-Replica step function.\"\"\"\n",
        "    x, y = inputs\n",
        "    with tf.GradientTape() as tape:\n",
        "      predictions = multi_worker_model(x, training=True)\n",
        "      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "          from_logits=True,\n",
        "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
        "      loss = tf.nn.compute_average_loss(\n",
        "          per_batch_loss, global_batch_size=global_batch_size)\n",
        "\n",
        "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
        "    optimizer.apply_gradients(\n",
        "        zip(grads, multi_worker_model.trainable_variables))\n",
        "    train_accuracy.update_state(y, predictions)\n",
        "    return loss\n",
        "\n",
        "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
        "  return strategy.reduce(\n",
        "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)"
      ]
    },
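    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss scaling in `train_step` can be checked by hand. The following pure-Python sketch uses made-up per-example losses for two replicas and mirrors the arithmetic only: each replica divides the sum of its per-example losses by the global batch size, and summing those values across replicas gives the mean over the whole global batch.\n",
        "\n",
        "```python\n",
        "global_batch_size = 4  # two replicas with two examples each (made-up numbers)\n",
        "per_replica_example_losses = [[1.0, 3.0], [2.0, 6.0]]\n",
        "\n",
        "# Per replica: sum of per-example losses divided by the GLOBAL batch size,\n",
        "# the scaling applied when `global_batch_size` is passed to\n",
        "# `tf.nn.compute_average_loss`.\n",
        "per_replica_losses = [\n",
        "    sum(losses) / global_batch_size for losses in per_replica_example_losses]\n",
        "\n",
        "# Summing across replicas corresponds to the `ReduceOp.SUM` step.\n",
        "reduced_loss = sum(per_replica_losses)\n",
        "\n",
        "# Reference value: the plain mean over all examples in the global batch.\n",
        "global_mean = sum(\n",
        "    sum(losses) for losses in per_replica_example_losses) / global_batch_size\n",
        "\n",
        "print(per_replica_losses, reduced_loss, global_mean)\n",
        "```\n",
        "\n",
        "Dividing by the per-replica batch size instead would make the summed loss too large by a factor equal to the number of replicas."
      ]
    },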
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eFXHsUVBy0Rx"
      },
      "source": [
        "### Checkpoint saving and restoring\n",
        "\n",
        "In a custom training loop, you need to handle checkpointing yourself instead of relying on a Keras callback. Checkpointing lets you save the model's weights and restore them without having to save the whole model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LcFO6x1KyjhI"
      },
      "outputs": [],
      "source": [
        "from multiprocessing import util\n",
        "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
        "\n",
        "def _is_chief(task_type, task_id):\n",
        "  return task_type is None or task_type == 'chief' or (task_type == 'worker' and\n",
        "                                                       task_id == 0)\n",
        "def _get_temp_dir(dirpath, task_id):\n",
        "  base_dirpath = 'workertemp_' + str(task_id)\n",
        "  temp_dir = os.path.join(dirpath, base_dirpath)\n",
        "  tf.io.gfile.makedirs(temp_dir)\n",
        "  return temp_dir\n",
        "\n",
        "def write_filepath(filepath, task_type, task_id):\n",
        "  dirpath = os.path.dirname(filepath)\n",
        "  base = os.path.basename(filepath)\n",
        "  if not _is_chief(task_type, task_id):\n",
        "    dirpath = _get_temp_dir(dirpath, task_id)\n",
        "  return os.path.join(dirpath, base)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P7fabUIEW7-M"
      },
      "source": [
        "Note: Checkpointing and saving need to happen on each worker, and each worker needs to write to a different path; otherwise they would overwrite each other.\n",
        "If you choose to checkpoint/save only on the chief, this can lead to deadlock and is not recommended."
      ]
    },
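    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see which paths the workers end up writing to, here is a simplified sketch of the same path logic. It only computes the paths and does not create any directories; `sketch_write_filepath` and the `/tmp/ckpt` location are hypothetical names used for illustration:\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def is_chief(task_type, task_id):\n",
        "  return (task_type is None or task_type == 'chief' or\n",
        "          (task_type == 'worker' and task_id == 0))\n",
        "\n",
        "def sketch_write_filepath(filepath, task_type, task_id):\n",
        "  # Same path logic as `write_filepath`, without creating directories.\n",
        "  dirpath, base = os.path.split(filepath)\n",
        "  if not is_chief(task_type, task_id):\n",
        "    # Non-chief workers write into a per-worker temporary subdirectory.\n",
        "    dirpath = os.path.join(dirpath, 'workertemp_' + str(task_id))\n",
        "  return os.path.join(dirpath, base)\n",
        "\n",
        "chief_path = sketch_write_filepath('/tmp/ckpt', 'worker', 0)\n",
        "worker_path = sketch_write_filepath('/tmp/ckpt', 'worker', 1)\n",
        "print(chief_path)\n",
        "print(worker_path)\n",
        "```\n",
        "\n",
        "The chief keeps the requested path, while worker 1 is redirected to a `workertemp_1` subdirectory, so the two never write to the same location."
      ]
    },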
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrcdPHtG4ObO"
      },
      "source": [
        "Here, you'll create one `tf.train.Checkpoint` that tracks the model, which is managed by a `tf.train.CheckpointManager` so that only the latest checkpoint is preserved."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4rURT2pI4aqV"
      },
      "outputs": [],
      "source": [
        "epoch = tf.Variable(\n",
        "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
        "step_in_epoch = tf.Variable(\n",
        "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
        "    name='step_in_epoch')\n",
        "task_type, task_id = (strategy.cluster_resolver.task_type,\n",
        "                      strategy.cluster_resolver.task_id)\n",
        "\n",
        "checkpoint = tf.train.Checkpoint(\n",
        "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
        "\n",
        "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)\n",
        "checkpoint_manager = tf.train.CheckpointManager(\n",
        "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RO7cbN40XD5v"
      },
      "source": [
        "Now, when you need to restore, you can find the latest checkpoint saved using the convenient `tf.train.latest_checkpoint` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gniynaQj6HMV"
      },
      "outputs": [],
      "source": [
        "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "if latest_checkpoint:\n",
        "  checkpoint.restore(latest_checkpoint)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1j9JuI-h6ObW"
      },
      "source": [
        "After restoring the checkpoint, you can continue training with your custom training loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kZzXZCh45FY6"
      },
      "outputs": [],
      "source": [
        "num_epochs = 3\n",
        "num_steps_per_epoch = 70\n",
        "\n",
        "while epoch.numpy() < num_epochs:\n",
        "  iterator = iter(multi_worker_dataset)\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "\n",
        "  while step_in_epoch.numpy() < num_steps_per_epoch:\n",
        "    total_loss += train_step(iterator)\n",
        "    num_batches += 1\n",
        "    step_in_epoch.assign_add(1)\n",
        "\n",
        "  train_loss = total_loss / num_batches\n",
        "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
        "                %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
        "\n",
        "  train_accuracy.reset_states()\n",
        "\n",
        "  # Once the `CheckpointManager` is set up, you're now ready to save, and remove\n",
        "  # the checkpoints non-chief workers saved.\n",
        "  checkpoint_manager.save()\n",
        "  if not _is_chief(task_type, task_id):\n",
        "    tf.io.gfile.rmtree(write_checkpoint_dir)\n",
        "\n",
        "  epoch.assign_add(1)\n",
        "  step_in_epoch.assign(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0W1Osks466DE"
      },
      "source": [
        "## Full code setup on workers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jfYpmIxO6Jck"
      },
      "source": [
        "To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n",
        "\n",
        "Like the `mnist.py` file written earlier, here is the `main.py` file. It\n",
        "contains the same code you walked through step by step earlier in this Colab, written to a file so that each worker can run it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MIDCESkVzN6M"
      },
      "outputs": [],
      "source": [
        "%%writefile main.py\n",
        "#@title File: `main.py`\n",
        "import os\n",
        "import json\n",
        "import tensorflow as tf\n",
        "import mnist\n",
        "from multiprocessing import util\n",
        "\n",
        "per_worker_batch_size = 64\n",
        "tf_config = json.loads(os.environ['TF_CONFIG'])\n",
        "num_workers = len(tf_config['cluster']['worker'])\n",
        "global_batch_size = per_worker_batch_size * num_workers\n",
        "\n",
        "num_epochs = 3\n",
        "num_steps_per_epoch=70\n",
        "\n",
        "# Checkpoint saving and restoring\n",
        "def _is_chief(task_type, task_id):\n",
        "  return task_type is None or task_type == 'chief' or (task_type == 'worker' and\n",
        "                                                       task_id == 0)\n",
        "def _get_temp_dir(dirpath, task_id):\n",
        "  base_dirpath = 'workertemp_' + str(task_id)\n",
        "  temp_dir = os.path.join(dirpath, base_dirpath)\n",
        "  tf.io.gfile.makedirs(temp_dir)\n",
        "  return temp_dir\n",
        "\n",
        "def write_filepath(filepath, task_type, task_id):\n",
        "  dirpath = os.path.dirname(filepath)\n",
        "  base = os.path.basename(filepath)\n",
        "  if not _is_chief(task_type, task_id):\n",
        "    dirpath = _get_temp_dir(dirpath, task_id)\n",
        "  return os.path.join(dirpath, base)\n",
        "\n",
        "checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')\n",
        "\n",
        "# Define Strategy\n",
        "strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
        "\n",
        "with strategy.scope():\n",
        "  # Model building/compiling need to be within `strategy.scope()`.\n",
        "  multi_worker_model = mnist.build_cnn_model()\n",
        "\n",
        "  multi_worker_dataset = strategy.distribute_datasets_from_function(\n",
        "      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))        \n",
        "  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)\n",
        "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='train_accuracy')\n",
        "\n",
        "@tf.function\n",
        "def train_step(iterator):\n",
        "  \"\"\"Training step function.\"\"\"\n",
        "\n",
        "  def step_fn(inputs):\n",
        "    \"\"\"Per-Replica step function.\"\"\"\n",
        "    x, y = inputs\n",
        "    with tf.GradientTape() as tape:\n",
        "      predictions = multi_worker_model(x, training=True)\n",
        "      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "          from_logits=True,\n",
        "          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)\n",
        "      loss = tf.nn.compute_average_loss(\n",
        "          per_batch_loss, global_batch_size=global_batch_size)\n",
        "\n",
        "    grads = tape.gradient(loss, multi_worker_model.trainable_variables)\n",
        "    optimizer.apply_gradients(\n",
        "        zip(grads, multi_worker_model.trainable_variables))\n",
        "    train_accuracy.update_state(y, predictions)\n",
        "\n",
        "    return loss\n",
        "\n",
        "  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))\n",
        "  return strategy.reduce(\n",
        "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
        "\n",
        "epoch = tf.Variable(\n",
        "    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')\n",
        "step_in_epoch = tf.Variable(\n",
        "    initial_value=tf.constant(0, dtype=tf.dtypes.int64),\n",
        "    name='step_in_epoch')\n",
        "\n",
        "task_type, task_id = (strategy.cluster_resolver.task_type,\n",
        "                      strategy.cluster_resolver.task_id)\n",
        "\n",
        "checkpoint = tf.train.Checkpoint(\n",
        "    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)\n",
        "\n",
        "write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)\n",
        "checkpoint_manager = tf.train.CheckpointManager(\n",
        "    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)\n",
        "\n",
        "# Restoring the checkpoint\n",
        "latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "if latest_checkpoint:\n",
        "  checkpoint.restore(latest_checkpoint)\n",
        "\n",
        "# Resume our CTL training\n",
        "while epoch.numpy() < num_epochs:\n",
        "  iterator = iter(multi_worker_dataset)\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "\n",
        "  while step_in_epoch.numpy() < num_steps_per_epoch:\n",
        "    total_loss += train_step(iterator)\n",
        "    num_batches += 1\n",
        "    step_in_epoch.assign_add(1)\n",
        "\n",
        "  train_loss = total_loss / num_batches\n",
        "  print('Epoch: %d, accuracy: %f, train_loss: %f.'\n",
        "                %(epoch.numpy(), train_accuracy.result(), train_loss))\n",
        "  \n",
        "  train_accuracy.reset_states()\n",
        "\n",
        "  checkpoint_manager.save()\n",
        "  if not _is_chief(task_type, task_id):\n",
        "    tf.io.gfile.rmtree(write_checkpoint_dir)\n",
        "\n",
        "  epoch.assign_add(1)\n",
        "  step_in_epoch.assign(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ItVOvPN1qnZ6"
      },
      "source": [
        "## Train and Evaluate\n",
        "The current directory now contains both Python files:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bi6x05Sr60O9"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "ls *.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qmEEStPS6vR_"
      },
      "source": [
        "Now JSON-serialize the `tf_config` dictionary and add it to the environment variables as `TF_CONFIG`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9uu3g7vV7Bbt"
      },
      "outputs": [],
      "source": [
        "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MsY3dQLK7jdf"
      },
      "source": [
        "Now, you can launch a worker process that will run the `main.py` and use the `TF_CONFIG`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "txMXaq8d8N_S"
      },
      "outputs": [],
      "source": [
        "# first kill any previous runs\n",
        "%killbgscripts"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qnSma_Ck7r-r"
      },
      "outputs": [],
      "source": [
        "%%bash --bg\n",
        "python main.py &> job_0.log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZChyazqS7v0P"
      },
      "source": [
        "There are a few things to note about the above command:\n",
        "\n",
        "1. It uses the `%%bash` which is a [notebook \"magic\"](https://ipython.readthedocs.io/en/stable/interactive/magics.html) to run some bash commands.\n",
        "2. It uses the `--bg` flag to run the `bash` process in the background, because this worker will not terminate. It waits for all the workers before it starts.\n",
        "\n",
        "The backgrounded worker process won't print output to this notebook, so the `&>` redirects its output to a file, so you can see what happened.\n",
        "\n",
        "So, wait a few seconds for the process to start up:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hm2yrULE9281"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "time.sleep(20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFPoNxg_9_Mx"
      },
      "source": [
        "Now look what's been output to the worker's logfile so far:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vZEOuVgQ9-hn"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "cat job_0.log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RqZhVF7L_KOy"
      },
      "source": [
        "The last line of the log file should say: `Started server with target: grpc://localhost:12345`. The first worker is now ready, and is waiting for all the other worker(s) to be ready to proceed."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pi8vPNNA_l4a"
      },
      "source": [
        "So update the `tf_config` for the second worker's process to pick up:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lAiYkkPu_Jqd"
      },
      "outputs": [],
      "source": [
        "tf_config['task']['index'] = 1\n",
        "os.environ['TF_CONFIG'] = json.dumps(tf_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0AshGVO0_x0w"
      },
      "source": [
        "Now launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ESVtyQ9_xjx"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "python main.py > /dev/null 2>&1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hX4FA2O2AuAn"
      },
      "source": [
        "Now if you recheck the logs written by the first worker you'll see that it participated in training that model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rc6hw3yTBKXX"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "cat job_0.log"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sG5_1UgrgniF"
      },
      "outputs": [],
      "source": [
        "# Delete the `TF_CONFIG`, and kill any background tasks so they don't affect the next section.\n",
        "os.environ.pop('TF_CONFIG', None)\n",
        "%killbgscripts"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bhxMXa0AaZkK"
      },
      "source": [
        "## Multi worker training in depth\n",
        "\n",
        "This tutorial has demonstrated a `Custom Training Loop` workflow of the multi-worker setup. A detailed description of other topics is available in the [`model.fit's guide`](https://colab.sandbox.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb) of the multi-worker setup and applicable to CTLs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ega2hdOQEmy_"
      },
      "source": [
        "## See also\n",
        "1. [Distributed Training in TensorFlow](https://www.tensorflow.org/guide/distributed_training) guide provides an overview of the available distribution strategies.\n",
        "2. [Official models](https://github.com/tensorflow/models/tree/master/official), many of which can be configured to run multiple distribution strategies.\n",
        "3. The [Performance section](../../guide/function.ipynb) in the guide provides information about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "multi_worker_with_ctl.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
